
import pandas as pd
import argparse
import numpy as np
import torch
import os
import datetime
import time
from matplotlib import pyplot as plt
from data.datautils import Dataset_ETT_hour,batch_x_ffts
from utils.util import EarlyStopping,_logger
from torch.utils.data import DataLoader
from model.encoder import Time_Frequence_Mul
from model.decoder import linear_Decoder,Attention_Decoder
from trainer import Trainer
from Config import Configs


def build_arg_parser():
    """Build the command-line parser for this training script.

    Returns:
        argparse.ArgumentParser: parser with every option this script accepts.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_path', default="d:\\data\\etth", type=str, help="数据的根路径")
    parser.add_argument('--data_name', default='ETTh1.csv', type=str, help="数据集名字")
    parser.add_argument('--save_path', default='./resModel', type=str, help="模型存储的地方")
    parser.add_argument('--experiment_description', default='pretrain', type=str, help='帮助记住这是干啥的')
    parser.add_argument('--modelname', default='linear_5_100_50.pth', type=str, help=' name of saved model.')
    parser.add_argument('--seed', default=3678, type=int, help="random seed")
    parser.add_argument('--lamubda', default=1, type=int, help='regulational size. ')
    parser.add_argument('--patience', default=30, type=int, help='early stopping.')
    parser.add_argument('--logs_save_dir', default='../experiments_logs', type=str, help='saving directory')
    parser.add_argument('--epoches', default=400, type=int, help='the epoches of learning. ')
    parser.add_argument('--training_mode', default='pre_train', type=str, help='pre_train, training')
    parser.add_argument('--run_description', default='run1', type=str, help='Experiment Description')
    # NOTE(review): no type/nargs is given, so a value passed on the command
    # line arrives as a single string rather than a list of ints. The script
    # below reads the window sizes from configs.SIZE, not from this option.
    parser.add_argument('--size', default=[96, 24, 24], help='size for learning , training, testing')
    # parser.add_argument('--device',default=0,type=int,help ='training device ,cpu or gpu. ')
    return parser


def elapsed_since(start, now=None):
    """Return the wall-clock time elapsed since ``start`` as a timedelta.

    Args:
        start: start timestamp in seconds, as returned by ``time.time()``.
        now: end timestamp in seconds; defaults to the current ``time.time()``.

    Returns:
        datetime.timedelta: ``now - start``.
    """
    if now is None:
        now = time.time()
    return datetime.timedelta(seconds=now - start)


def main():
    """Parse the CLI options, build data loaders and models, and run the Trainer."""
    args = build_arg_parser().parse_args()
    configs = Configs()

    seed = args.seed
    experiment_log_dir = os.path.join(
        "chec",
        args.logs_save_dir,
        args.experiment_description,
        args.training_mode + f"_seed_{seed}",
    )
    # Make sure the directory exists before the log file is opened inside it;
    # this is a no-op when it is already there.
    os.makedirs(experiment_log_dir, exist_ok=True)

    timestamp = datetime.datetime.now().strftime('%d_%m_%Y_%H_%M_%S')
    log_file_name = os.path.join(experiment_log_dir, f"logs_{timestamp}.log")
    logger = _logger(log_file_name)
    # NOTE(review): data_name is only logged; it is not passed to the dataset
    # constructor below.
    logger.debug(f'Data_name: {args.data_name}')
    logger.debug(f'Mode:    {args.training_mode}')
    logger.debug("Data loaded ...")

    # Seed the random number generators.
    torch.manual_seed(seed)
    # NOTE(review): deterministic is left at False as in the original; fully
    # reproducible cuDNN runs would presumably want True - confirm intent.
    torch.backends.cudnn.deterministic = False
    torch.backends.cudnn.benchmark = False
    np.random.seed(seed)

    # Prepare the datasets and their loaders.
    train_data_set = Dataset_ETT_hour(args.data_path, flag='train', size=configs.SIZE)
    test_data_set = Dataset_ETT_hour(args.data_path, flag='test', size=configs.SIZE)
    valid_data_set = Dataset_ETT_hour(args.data_path, flag='val', size=configs.SIZE)

    loader_options = dict(
        batch_size=configs.BATCH_SIZE,
        shuffle=configs.SHUFFLE_FLAG,
        drop_last=configs.DROP_LAST,
    )
    train_dl = DataLoader(train_data_set, **loader_options)
    test_dl = DataLoader(test_data_set, **loader_options)
    valid_dl = DataLoader(valid_data_set, **loader_options)

    # Early stopping now honours --patience instead of a hard-coded 3.
    # NOTE(review): this object is not handed to Trainer below, so it has no
    # effect on the run as written.
    early_stopping = EarlyStopping(patience=args.patience, verbose=True)

    # Build the encoder, its optimizer and the decoder.
    model = Time_Frequence_Mul(
        configs.INPUT_DIMS,
        configs.OUTPUT_DIMS,
        configs.HIDDEN_DIMS,
        configs.lr,
        configs.BATCH_SIZE,
        configs.DEVICE,
        configs.VARS,
        configs,
    ).to(configs.DEVICE)
    optimizer = torch.optim.Adam(model.parameters(), lr=configs.lr, weight_decay=3e-4)
    decoder = linear_Decoder(
        configs.SEQ_LEN,
        configs.LABEL_LEN + configs.PRED_LEN,
        configs.HIDDEN_DIMS,
        configs.lr,
        configs.BATCH_SIZE,
        configs.DEVICE,
        configs.VARS,
    ).to(configs.DEVICE)

    time_now = time.time()

    Trainer(model, optimizer, train_dl, test_dl, valid_dl, configs,
            args.training_mode, logger, decoder, experiment_log_dir)

    # time_now is a float from time.time(); subtracting it from a datetime
    # raises TypeError, so the elapsed time is derived from two time.time()
    # readings and rendered as a timedelta.
    logger.debug(f"Training time is : {elapsed_since(time_now)}")


if __name__ == '__main__':
    main()




